from gymnasium.envs.registration import register


##################################################
# Block Bandit
# All four versions share one entry point.
# -v0 passes no kwargs, so it runs with whatever defaults the env class defines.
# -v1/-v2/-v3 pin total_trials=400, block_len_var=10 and prob_pool=[0.8, 0.2],
# and differ only in num_blocks (10, 5, 20 respectively).
# NOTE(review): max_episode_steps=1000 exceeds total_trials=400; presumably the
# env terminates on its own after total_trials and the TimeLimit is a backstop.
register(
    id="BlockBandit2ArmCoupledEasy-v0",
    entry_point="environments.bandit.bandit:BlockBandit2ArmCoupledEasy",
    max_episode_steps=1000,
)

register(
    id="BlockBandit2ArmCoupledEasy-v1",
    entry_point="environments.bandit.bandit:BlockBandit2ArmCoupledEasy",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks=10,
        block_len_var=10,
        prob_pool=[0.8, 0.2],
    ),
)

register(
    id="BlockBandit2ArmCoupledEasy-v2",
    entry_point="environments.bandit.bandit:BlockBandit2ArmCoupledEasy",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks=5,
        block_len_var=10,
        prob_pool=[0.8, 0.2],
    ),
)

register(
    id="BlockBandit2ArmCoupledEasy-v3",
    entry_point="environments.bandit.bandit:BlockBandit2ArmCoupledEasy",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks=20,
        block_len_var=10,
        prob_pool=[0.8, 0.2],
    ),
)

# Multi-probability and mixed block-bandit variants. Each pins
# total_trials=400 and block_len_var=10 and caps episodes at 1000 steps.
# NOTE(review): high_reward_prob_pool is passed as a *set* ({0.9, 0.7, 0.5}),
# unlike the list-valued *_pool kwargs elsewhere in this file. Sets do not keep
# the written order; presumably the env relies on set semantics or converts it
# itself. Confirm against the env before changing it to a list.
register(
    id="BlockBandit2ArmCoupledMultipleProb-v0",
    entry_point="environments.bandit.bandit:BlockBandit2ArmCoupledMultipleProb",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks=10,
        block_len_var=10,
        high_reward_prob_pool={0.9, 0.7, 0.5},
    ),
)

# Registered for an over-riding reset method.
# NOTE(review): entry point and kwargs are identical to -v0, so any difference
# must live in the env class itself. Confirm this id is still needed.
register(
    id="BlockBandit2ArmCoupledMultipleProb-v1",
    entry_point="environments.bandit.bandit:BlockBandit2ArmCoupledMultipleProb",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks=10,
        block_len_var=10,
        high_reward_prob_pool={0.9, 0.7, 0.5},
    ),
)

# Takes a pool of block counts (num_blocks_pool) instead of a single num_blocks.
register(
    id="BlockBandit2ArmCoupledEasyMixedBlockLength-v0",
    entry_point="environments.bandit.bandit:BlockBandit2ArmCoupledEasyMixedBlockLength",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks_pool=[5, 20],
        block_len_var=10,
        prob_pool=[0.8, 0.2],
    ),
)

register(
    id="BlockBandit2ArmMixedCoupledMultipleProbAndIndependent-v0",
    entry_point="environments.bandit.bandit:BlockBandit2ArmMixedCoupledMultipleProbAndIndependent",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks=10,
        block_len_var=10,
        high_reward_prob_pool={0.9, 0.7, 0.5},
    ),
)


##################################################
# Baited block bandit
# Baited variants reuse the block-bandit kwargs layout: total_trials=400,
# num_blocks=10, block_len_var=10, and a 1000-step episode cap.
register(
    id="BaitedBlockBandit2ArmCoupledEasy-v0",
    entry_point="environments.bandit.bandit:BaitedBlockBandit2ArmCoupledEasy",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks=10,
        block_len_var=10,
        prob_pool=[0.8, 0.2],
    ),
)

# NOTE(review): high_reward_prob_pool is a set literal (unordered). Presumably
# the env expects set semantics; confirm before converting it to a list.
register(
    id="BaitedBlockBandit2ArmCoupledMultipleProb-v0",
    entry_point="environments.bandit.bandit:BaitedBlockBandit2ArmCoupledMultipleProb",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        num_blocks=10,
        block_len_var=10,
        high_reward_prob_pool={0.9, 0.7, 0.5},
    ),
)


##################################################
# RandomWalk bandit
# Gaussian random-walk bandits. All versions start the walk at 0.5 and run
# 400 trials under a 1000-step cap.
# The drift rate differs per version: -v0 0.05, -v1 0.2, -v2 0.0125.
register(
    id="RandomWalkBandit2ArmGaussian-v0",
    entry_point="environments.bandit.bandit:RandomWalkBandit2ArmGaussian",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        random_walk_drift_rate=0.05,
        random_walk_start=0.5,
    ),
)

register(
    id="RandomWalkBandit2ArmGaussian-v1",
    entry_point="environments.bandit.bandit:RandomWalkBandit2ArmGaussian",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        random_walk_drift_rate=0.2,
        random_walk_start=0.5,
    ),
)

register(
    id="RandomWalkBandit2ArmGaussian-v2",
    entry_point="environments.bandit.bandit:RandomWalkBandit2ArmGaussian",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        random_walk_drift_rate=0.0125,
        random_walk_start=0.5,
    ),
)

# Mixed variant: takes a pool holding the -v1 and -v2 drift rates instead of
# a single rate.
register(
    id="RandomWalkBandit2ArmGaussianMixedDriftRate-v0",
    entry_point="environments.bandit.bandit:RandomWalkBandit2ArmGaussianMixedDriftRate",
    max_episode_steps=1000,
    kwargs=dict(
        total_trials=400,
        random_walk_drift_rate_pool=[0.2, 0.0125],
        random_walk_start=0.5,
    ),
)


##################################################
# Timed block bandit
# Timed block bandits get a much larger step cap (20000) than the untimed
# ones (1000). The -v0 ids pass no kwargs and so use the env class defaults.
register(
    id="TimedBlockBandit2ArmCoupledEasy-v0",
    entry_point="environments.bandit.bandit:TimedBlockBandit2ArmCoupledEasy",
    max_episode_steps=20000,
)

register(
    id="TimedBlockBandit2ArmCoupledMultipleProb-v0",
    entry_point="environments.bandit.bandit:TimedBlockBandit2ArmCoupledMultipleProb",
    max_episode_steps=20000,
)

# -v1 pins the trial layout explicitly. trial_len_range lists candidate
# per-trial lengths. NOTE(review): presumably these are measured in env steps;
# confirm. high_reward_prob_pool is a set literal (unordered), as in the other
# multi-prob variants.
register(
    id="TimedBlockBandit2ArmCoupledMultipleProb-v1",
    entry_point="environments.bandit.bandit:TimedBlockBandit2ArmCoupledMultipleProb",
    max_episode_steps=20000,
    kwargs=dict(
        total_trials=400,
        num_blocks=10,
        block_len_var=10,
        trial_len_range=[4, 5, 6, 7],
        high_reward_prob_pool={0.9, 0.7, 0.5},
    ),
)


##################################################
# Dynamic foraging tasks
# Coupled-block dynamic-foraging task. All versions share one entry point.
# Where num_trials is given, max_episode_steps is set to the same value.
register(
    id="CoupledBlockDF-v0",
    entry_point="environments.dynamic_foraging.dynamic_foraging:CoupledBlockTask",
    max_episode_steps=1000,
)

register(
    id="CoupledBlockDF-v1",
    entry_point="environments.dynamic_foraging.dynamic_foraging:CoupledBlockTask",
    max_episode_steps=400,
    kwargs=dict(
        num_trials=400,
    ),
)

register(
    id="CoupledBlockDF-v2",
    entry_point="environments.dynamic_foraging.dynamic_foraging:CoupledBlockTask",
    max_episode_steps=1000,
    kwargs=dict(
        num_trials=1000,
        block_beta=5,
    ),
)

# Each reward-probability pair sums to 0.45 and is annotated with its ratio.
# The expressions are kept unevaluated so the floats are bit-identical.
register(
    id="CoupledBlockDF-v3",
    entry_point="environments.dynamic_foraging.dynamic_foraging:CoupledBlockTask",
    max_episode_steps=500,
    kwargs=dict(
        num_trials=500,
        p_reward_pairs=[
            [0.45 / 4 * 1, 0.45 / 4 * 3],  # 1:3
            [0.45 / 7 * 1, 0.45 / 7 * 6],  # 1:6
            [0.05, 0.40],  # 1:8
        ],
    ),
)

# Same ratios as -v3, plus 1:1, with each pair summing to 1.0 instead of 0.45.
register(
    id="CoupledBlockDF-v4",
    entry_point="environments.dynamic_foraging.dynamic_foraging:CoupledBlockTask",
    max_episode_steps=500,
    kwargs=dict(
        num_trials=500,
        p_reward_pairs=[
            [0.5, 0.5],  # 1:1
            [1.0 / 4 * 1, 1.0 / 4 * 3],  # 1:3
            [1.0 / 7 * 1, 1.0 / 7 * 6],  # 1:6
            [1.0 / 9 * 1, 1.0 / 9 * 8],  # 1:8
        ],
    ),
)

# Remaining dynamic-foraging tasks are registered with no kwargs (env-class
# defaults) and a 1000-step cap.
register(
    id="UncoupledBlockDF-v0",
    entry_point="environments.dynamic_foraging.dynamic_foraging:UncoupledBlockTask",
    max_episode_steps=1000,
)

register(
    id="RandomWalkDF-v0",
    entry_point="environments.dynamic_foraging.dynamic_foraging:RandomWalkTask",
    max_episode_steps=1000,
)

